/**
 * (C) Copyright IBM Corp. 2010, 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
 */

package com.ibm.bi.dml.hops.rewrite;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.ibm.bi.dml.hops.Hop;
import com.ibm.bi.dml.hops.HopsException;
import com.ibm.bi.dml.hops.IndexingOp;
import com.ibm.bi.dml.hops.LeftIndexingOp;
import com.ibm.bi.dml.hops.LiteralOp;
import com.ibm.bi.dml.parser.Expression.DataType;
import com.ibm.bi.dml.parser.Expression.ValueType;

/**
 * Rule: Indexing vectorization. This rewrite rule set simplifies
 * multiple right / left indexing accesses within a DAG into row/column
 * index accesses, which is beneficial for two reasons: (1) it is an 
 * enabler for later row/column partitioning, and (2) it reduces the number
 * of operations over potentially large data (i.e., prevents unnecessary MR 
 * operations and reduces pressure on the buffer pool due to copy on write
 * on left indexing).
 * 
 * Currently only the left indexing rewrite is applied; the right indexing
 * rewrite ({@code vectorizeRightIndexing}) is implemented but disabled.
 */
public class RewriteIndexingVectorization extends HopRewriteRule
{
	
	private static final Log LOG = LogFactory.getLog(RewriteIndexingVectorization.class.getName());
	
	/**
	 * Applies the indexing vectorization to the DAGs below all given roots.
	 * The roots themselves are never replaced; rewrites only modify their
	 * (transitive) inputs.
	 * 
	 * @param roots DAG roots; if null, null is returned
	 * @param state program rewrite status (not used by this rule)
	 * @return the given list of roots
	 * @throws HopsException
	 */
	@Override
	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) 
		throws HopsException
	{
		if( roots == null )
			return roots;

		for( Hop h : roots )
			rule_IndexingVectorization( h );
		
		return roots;
	}

	/**
	 * Applies the indexing vectorization to the DAG below a single root.
	 * 
	 * @param root DAG root; if null, null is returned
	 * @param state program rewrite status (not used by this rule)
	 * @return the given root
	 * @throws HopsException
	 */
	@Override
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) 
		throws HopsException
	{
		if( root == null )
			return root;
		
		rule_IndexingVectorization( root );
		
		return root;
	}


	/**
	 * Recursively traverses the DAG below the given hop. For every input, it
	 * first applies the left indexing vectorization and then descends into
	 * that input. Hops already marked as {@code DONE} are skipped; the given
	 * hop is marked {@code DONE} after all its inputs have been processed.
	 * 
	 * @param hop current hop of the traversal
	 * @throws HopsException
	 */
	private void rule_IndexingVectorization( Hop hop ) 
		throws HopsException 
	{
		if(hop.getVisited() == Hop.VisitStatus.DONE)
			return;
		
		//recursively process children
		for( int i=0; i<hop.getInput().size(); i++)
		{
			Hop hi = hop.getInput().get(i);
			
			//apply indexing vectorization rewrites
			//MB: disabled right indexing rewrite because (1) piggybacked in MR anyway, (2) usually
			//not too much overhead, and (3) makes literal replacement more difficult
			//vectorizeRightIndexing( hi ); //e.g., multiple rightindexing X[i,1], X[i,3] -> X[i,];
			vectorizeLeftIndexing( hi );  //e.g., multiple left indexing X[i,1], X[i,3] -> X[i,]; 
			
			//process childs recursively after rewrites 
			rule_IndexingVectorization( hi );
		}

		hop.setVisited(Hop.VisitStatus.DONE);
	}

	/**
	 * Merges multiple single-cell right indexing operations over the same
	 * input into one row (or column) indexing plus cell indexing on the
	 * resulting vector, e.g., X[i,1], X[i,3] -> tmp=X[i,]; tmp[1,1], tmp[1,3].
	 * The row rewrite is tried first; the column rewrite only if the row
	 * rewrite did not apply. Candidates are the consumers of the common input
	 * whose row (column) lower-bound expression is the identical hop, i.e.,
	 * this relies on common subexpression elimination.
	 * 
	 * IndexingOp inputs as used here: 0 input data, 1/2 row lower/upper,
	 * 3/4 col lower/upper.
	 * 
	 * Note: unnecessary row or column indexing then later removed via
	 * dynamic rewrites
	 * 
	 * @param hop candidate hop (only IndexingOp is considered)
	 * @throws HopsException 
	 */
	@SuppressWarnings("unused")
	private void vectorizeRightIndexing( Hop hop )
		throws HopsException
	{
		if( hop instanceof IndexingOp ) //right indexing
		{
			IndexingOp ihop0 = (IndexingOp) hop;
			boolean isSingleRow = ihop0.getRowLowerEqualsUpper();
			boolean isSingleCol = ihop0.getColLowerEqualsUpper();
			boolean appliedRow = false;
			
			//search for multiple indexing in same row
			if( isSingleRow && isSingleCol ){
				Hop input = ihop0.getInput().get(0);
				//find candidate set
				//dependence on common subexpression elimination to find equal input / row expression
				ArrayList<Hop> ihops = new ArrayList<Hop>();
				ihops.add(ihop0);
				for( Hop c : input.getParent() ){
					if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
					   && ((IndexingOp) c).getRowLowerEqualsUpper() 
					   && c.getInput().get(1)==ihop0.getInput().get(1) )
					{
						ihops.add( c );
					}
				}
				//apply rewrite if found candidates
				if( ihops.size() > 1 ){
					//new row indexing operator
					IndexingOp newRix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, 
							            ihop0.getInput().get(1), ihop0.getInput().get(1), new LiteralOp(1), 
							            HopRewriteUtils.createValueHop(input, false), true, false); 
					HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newRix.refreshSizeInformation();
					//rewire current operator and all candidates
					for( Hop c : ihops ) {
						HopRewriteUtils.removeChildReference(c, input); //input data
						HopRewriteUtils.addChildReference(c, newRix, 0);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(1),1); //row lower expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 1);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2),2); //row upper expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 2);
						c.refreshSizeInformation();
					}
					
					appliedRow = true;
					LOG.debug("Applied vectorizeRightIndexingRow");
				}
			}
			
			//search for multiple indexing in same col
			if( isSingleRow && isSingleCol && !appliedRow ){
				Hop input = ihop0.getInput().get(0);
				//find candidate set
				//dependence on common subexpression elimination to find equal input / col expression
				ArrayList<Hop> ihops = new ArrayList<Hop>();
				ihops.add(ihop0);
				for( Hop c : input.getParent() ){
					if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
					   && ((IndexingOp) c).getColLowerEqualsUpper() 
					   && c.getInput().get(3)==ihop0.getInput().get(3) )
					{
						ihops.add( c );
					}
				}
				//apply rewrite if found candidates
				if( ihops.size() > 1 ){
					//new column indexing operator
					IndexingOp newRix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, 
							         new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
				                    ihop0.getInput().get(3), ihop0.getInput().get(3), false, true); 
					HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newRix.refreshSizeInformation();
					//rewire current operator and all candidates
					for( Hop c : ihops ) {
						HopRewriteUtils.removeChildReference(c, input); //input data
						HopRewriteUtils.addChildReference(c, newRix, 0);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(3),3); //col lower expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 3);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(4),4); //col upper expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 4);
						c.refreshSizeInformation();
					}

					LOG.debug("Applied vectorizeRightIndexingCol");
				}
			}
		}
	}
	
	/**
	 * Merges a simple chain of single-cell left indexing operations on the
	 * same row, e.g., X[i,1]=a; X[i,3]=b, into
	 * (1) a row extraction tmp1=X[i,],
	 * (2) the original chain applied to tmp1 with row index 1, and
	 * (3) a single row write-back X[i,]=chain.
	 * If the row rewrite does not apply, the analogous column rewrite is tried.
	 * 
	 * A chain is formed by following input 0 (the target) through
	 * LeftIndexingOps that have exactly one consumer. All chain members must
	 * use the identical row (column) lower-bound hop, which relies on common
	 * subexpression elimination.
	 * 
	 * LeftIndexingOp inputs as used here: 0 target, 1 assigned value,
	 * 2/3 row lower/upper, 4/5 col lower/upper.
	 * 
	 * @param hop candidate hop (only LeftIndexingOp is considered)
	 * @throws HopsException
	 */
	private void vectorizeLeftIndexing( Hop hop )
		throws HopsException
	{
		if( hop instanceof LeftIndexingOp ) //left indexing
		{
			LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
			boolean isSingleRow = ihop0.getRowLowerEqualsUpper();
			boolean isSingleCol = ihop0.getColLowerEqualsUpper();
			boolean appliedRow = false;
			
			if( isSingleRow && isSingleCol )
			{
				//collect simple chains (w/o multiple consumers) of left indexing ops
				ArrayList<Hop> ihops = new ArrayList<Hop>();
				ihops.add(ihop0);
				Hop current = ihop0;
				while( current.getInput().get(0) instanceof LeftIndexingOp ) {
					LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
					if(    tmp.getParent().size()>1  //multiple consumers, i.e., not a simple chain
						|| !tmp.getRowLowerEqualsUpper() //row merge not applicable
						|| tmp.getInput().get(2) != ihop0.getInput().get(2) //not the same row
						|| tmp.getInput().get(0).getDim2() <= 1 ) //target is single column or unknown 
					{
						break;
					}
					ihops.add( tmp );
					current = tmp;
				}
				
				//apply rewrite if found candidates
				if( ihops.size() > 1 ){
					Hop input = current.getInput().get(0);
					Hop rowExpr = ihop0.getInput().get(2); //keep before reset
					
					//new row indexing operator
					IndexingOp newRix = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, input, 
							            rowExpr, rowExpr, new LiteralOp(1), 
							            HopRewriteUtils.createValueHop(input, false), true, false); 
					HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newRix.refreshSizeInformation();
					
					//rewrite bottom left indexing operator
					HopRewriteUtils.removeChildReference(current, input); //input data
					HopRewriteUtils.addChildReference(current, newRix, 0);
					
					//reset row index all candidates
					for( Hop c : ihops ) {
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2), 2); //row lower expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 2);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(3), 3); //row upper expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 3);
						c.refreshSizeInformation();
					}
					
					//new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
					//(note: parents must be detached before creating newLix on top of ihop0)
					ArrayList<Hop> ihop0parents = new ArrayList<Hop>();
					ArrayList<Integer> ihop0parentsPos = new ArrayList<Integer>();
					detachFromParents(ihop0, ihop0parents, ihop0parentsPos);
					
					LeftIndexingOp newLix = new LeftIndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, input, ihop0, 
													rowExpr, rowExpr, new LiteralOp(1), 
													HopRewriteUtils.createValueHop(input, false), true, false); 
					HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newLix.refreshSizeInformation();
					
					attachToParents(newLix, ihop0parents, ihop0parentsPos);
					
					appliedRow = true;
					LOG.debug("Applied vectorizeLeftIndexingRow");
				}
			}
			
			if( isSingleRow && isSingleCol && !appliedRow )
			{
				//collect simple chains (w/o multiple consumers) of left indexing ops
				ArrayList<Hop> ihops = new ArrayList<Hop>();
				ihops.add(ihop0);
				Hop current = ihop0;
				while( current.getInput().get(0) instanceof LeftIndexingOp ) {
					LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
					if(    tmp.getParent().size()>1  //multiple consumers, i.e., not a simple chain
						|| !tmp.getColLowerEqualsUpper() //column merge not applicable
						|| tmp.getInput().get(4) != ihop0.getInput().get(4)  //not the same col
						|| tmp.getInput().get(0).getDim1() <= 1 )  //target is single row or unknown
					{
						break;
					}
					ihops.add( tmp );
					current = tmp;
				}
				//apply rewrite if found candidates
				if( ihops.size() > 1 ){
					Hop input = current.getInput().get(0);
					Hop colExpr = ihop0.getInput().get(4); //keep before reset
					
					//new column indexing operator
					IndexingOp newRix = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, input, 
							        new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),            
									colExpr, colExpr, false, true); 
					HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newRix.refreshSizeInformation();
					
					//rewrite bottom left indexing operator
					HopRewriteUtils.removeChildReference(current, input); //input data
					HopRewriteUtils.addChildReference(current, newRix, 0);
					
					//reset column index of all candidates
					for( Hop c : ihops ) {
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(4), 4); //col lower expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 4);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(5), 5); //col upper expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 5);
						c.refreshSizeInformation();
					}
					
					//new column left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
					//(note: parents must be detached before creating newLix on top of ihop0)
					ArrayList<Hop> ihop0parents = new ArrayList<Hop>();
					ArrayList<Integer> ihop0parentsPos = new ArrayList<Integer>();
					detachFromParents(ihop0, ihop0parents, ihop0parentsPos);
					
					LeftIndexingOp newLix = new LeftIndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, input, ihop0, 
							                        new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), 
													colExpr, colExpr, false, true); 
					HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newLix.refreshSizeInformation();
					
					attachToParents(newLix, ihop0parents, ihop0parentsPos);
					
					LOG.debug("Applied vectorizeLeftIndexingCol");
				}
			}
		}
	}
	
	/**
	 * Removes every reference from the current consumers of {@code hop} to
	 * {@code hop}. For each removal, appends the consumer to {@code parents}
	 * and the input position it was removed from to {@code positions}, in
	 * removal order. Each position is relative to the consumer's input list
	 * after all earlier removals.
	 * 
	 * NOTE(review): assumes the parent list holds one entry per input
	 * reference, so a consumer referencing {@code hop} k times is visited k times.
	 * 
	 * @param hop hop to detach from its consumers
	 * @param parents output list of detached consumers (appended to)
	 * @param positions output list of removed input positions (appended to)
	 * @throws HopsException
	 */
	private static void detachFromParents( Hop hop, ArrayList<Hop> parents, ArrayList<Integer> positions )
		throws HopsException
	{
		//copy the parent list, since removing references modifies it
		ArrayList<Hop> consumers = new ArrayList<Hop>(hop.getParent());
		for( Hop parent : consumers ) {
			int posp = HopRewriteUtils.getChildReferencePos(parent, hop);
			HopRewriteUtils.removeChildReferenceByPos(parent, hop, posp); //input data
			parents.add(parent);
			positions.add(posp);
		}
	}
	
	/**
	 * Re-attaches {@code hop} to the consumers recorded by
	 * {@code detachFromParents}, at the recorded input positions.
	 * 
	 * @param hop replacement hop to attach
	 * @param parents consumers in removal order
	 * @param positions removed input positions in removal order
	 * @throws HopsException
	 */
	private static void attachToParents( Hop hop, ArrayList<Hop> parents, ArrayList<Integer> positions )
		throws HopsException
	{
		//insert in reverse removal order: every recorded position is relative to the
		//input list after all earlier removals, so undoing the removals last-to-first
		//restores each reference at its original position (forward order misplaces
		//references if one consumer used the hop at multiple non-adjacent positions)
		for( int i=parents.size()-1; i>=0; i-- ) {
			Hop parent = parents.get(i);
			int posp = positions.get(i);
			HopRewriteUtils.addChildReference(parent, hop, posp);
		}
	}
}
